import torch
import tilelang
import tilelang.testing
from tilelang import language as T


# Kernel factory: builds a TileLang prim_func that calls T.cumsum on a 2-D
# T.view of a 1-D buffer obtained from T.alloc_shared. Judging by the names
# ("infer_layout", "buggy_kernel") this reproduces a layout-inference problem
# -- TODO confirm against the original bug report.
#
# `hidden` is used as the static row length in the tensor, buffer and view
# shapes; the row count stays a dynamic symbol (`num_tokens`).
# Returns the prim_func (as wrapped by the tilelang.jit decorator).
@tilelang.jit(
    pass_configs={
        # Both keys are set to True, i.e. the "disable warp specialized" and
        # "disable TMA lower" options are switched on for this compile.
        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
    },)
def _cumsum_view_infer_layout(hidden):
    # Symbolic leading dimension of the input tensor.
    num_tokens = T.dynamic('num_tokens')

    @T.prim_func
    def buggy_kernel(x: T.Tensor[(num_tokens, hidden), 'float']):
        # T.Kernel is given num_tokens and threads=128; `pid` is the bound
        # index used below to select one row of x.
        with T.Kernel(num_tokens, threads=128) as pid:
            smem = T.alloc_shared((hidden,), dtype='float')
            # Stage row `pid` of x in the 1-D buffer.
            T.copy(x[pid, :], smem)
            # Cumsum is requested along dim 1 of a (1, hidden) view of that
            # buffer. Nothing is copied back into x afterwards, so callers can
            # only observe that the kernel builds and launches.
            T.cumsum(T.view(smem, (1, hidden)), dim=1)

    return buggy_kernel


def test_cumsum_view_infer_layout():
    """Smoke-test the cumsum-over-view kernel on a single random row.

    Builds a (1, 128) float tensor on the CUDA device, obtains the kernel for
    that row length and launches it once. No output is inspected; the test
    passes as long as nothing raises.
    """
    row_len = 128
    sample = torch.randn((1, row_len), device='cuda', dtype=torch.float32)
    compiled = _cumsum_view_infer_layout(row_len)
    compiled(sample)

# Script entry point: when the file is executed directly, hand control to
# tilelang.testing.main() (presumably it discovers and runs the tests in this
# file -- confirm against the tilelang.testing implementation).
if __name__ == '__main__':
    tilelang.testing.main()
